{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/daiyuxin/anaconda3/envs/mindspore/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import mindspore\n",
    "from mindspore.dataset import text, GeneratorDataset, transforms\n",
    "from mindspore import nn, context\n",
    "\n",
    "from mindnlp.engine import Trainer, Evaluator\n",
    "from mindnlp.engine.callbacks import CheckpointCallback, BestModelCallback\n",
    "from mindnlp.metrics import Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# prepare dataset\n",
    "class SentimentDataset:\n",
    "    \"\"\"Sentiment Dataset\"\"\"\n",
    "\n",
    "    def __init__(self, path):\n",
    "        self.path = path\n",
    "        self._labels, self._text_a = [], []\n",
    "        self._load()\n",
    "\n",
    "    def _load(self):\n",
    "        with open(self.path, \"r\", encoding=\"utf-8\") as f:\n",
    "            dataset = f.read()\n",
    "        lines = dataset.split(\"\\n\")\n",
    "        for line in lines[1:-1]:\n",
    "            label, text_a = line.split(\"\\t\")\n",
    "            self._labels.append(int(label))\n",
    "            self._text_a.append(text_a)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self._labels[index], self._text_a[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2023-11-30 16:11:36--  https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz\n",
      "Resolving baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)... 36.110.192.178, 2409:8c04:1001:1002:0:ff:b001:368a\n",
      "Connecting to baidu-nlp.bj.bcebos.com (baidu-nlp.bj.bcebos.com)|36.110.192.178|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 1710581 (1.6M) [application/x-gzip]\n",
      "Saving to: ‘emotion_detection.tar.gz’\n",
      "\n",
      "emotion_detection.t 100%[===================>]   1.63M  6.36MB/s    in 0.3s    \n",
      "\n",
      "2023-11-30 16:11:37 (6.36 MB/s) - ‘emotion_detection.tar.gz’ saved [1710581/1710581]\n",
      "\n",
      "data/\n",
      "data/test.tsv\n",
      "data/infer.tsv\n",
      "data/dev.tsv\n",
      "data/train.tsv\n",
      "data/vocab.txt\n"
     ]
    }
   ],
   "source": [
    "# Download the emotion-detection archive and unpack it in the working directory.\n",
    "# The archive extracts into data/ (train.tsv, dev.tsv, test.tsv, infer.tsv, vocab.txt).\n",
    "!wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz\n",
    "!tar xvf emotion_detection.tar.gz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):\n",
    "    \"\"\"Turn a ``(label, text)`` source into a batched, tokenized dataset.\n",
    "\n",
    "    Args:\n",
    "        source: Object handed to ``GeneratorDataset``; each item supplies the\n",
    "            ``label`` and ``text_a`` columns, in that order.\n",
    "        tokenizer: Callable applied to every ``text_a`` value with\n",
    "            ``padding='max_length'``, ``truncation=True`` and\n",
    "            ``max_length=max_seq_len``.\n",
    "        max_seq_len: Length passed to the tokenizer as ``max_length``.\n",
    "        batch_size: Number of rows per batch.\n",
    "        shuffle: Forwarded unchanged to ``GeneratorDataset``.\n",
    "\n",
    "    Returns:\n",
    "        The batched dataset with the columns ``input_ids``, ``attention_mask``\n",
    "        and ``labels``.\n",
    "    \"\"\"\n",
    "\n",
    "    def encode(sentence):\n",
    "        # NOTE(review): positions 0 and 2 of the tokenizer result are taken to be\n",
    "        # input_ids and attention_mask -- confirm against the tokenizer in use.\n",
    "        encoded = tokenizer(sentence, padding='max_length', truncation=True, max_length=max_seq_len)\n",
    "        return encoded[0], encoded[2]\n",
    "\n",
    "    pipeline = GeneratorDataset(source, column_names=[\"label\", \"text_a\"], shuffle=shuffle)\n",
    "    # Text column first, then the label column, so the column order is unchanged.\n",
    "    pipeline = pipeline.map(\n",
    "        operations=encode,\n",
    "        input_columns=\"text_a\",\n",
    "        output_columns=['input_ids', 'attention_mask'],\n",
    "    )\n",
    "    pipeline = pipeline.map(\n",
    "        operations=[transforms.TypeCast(mindspore.int32)],\n",
    "        input_columns=\"label\",\n",
    "        output_columns='labels',\n",
    "    )\n",
    "    return pipeline.batch(batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from mindnlp.transformers import BertTokenizer\n",
    "# Load the pretrained 'bert-base-chinese' tokenizer used by the data pipeline and by predict().\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Read the three labelled splits, then build one batched pipeline per split.\n",
    "# Only the test split opts out of shuffling (shuffle=False).\n",
    "train_source = SentimentDataset(\"data/train.tsv\")\n",
    "dev_source = SentimentDataset(\"data/dev.tsv\")\n",
    "test_source = SentimentDataset(\"data/test.tsv\")\n",
    "\n",
    "dataset_train = process_dataset(train_source, tokenizer)\n",
    "dataset_val = process_dataset(dev_source, tokenizer)\n",
    "dataset_test = process_dataset(test_source, tokenizer, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['input_ids', 'attention_mask', 'labels']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the column names of the processed training pipeline.\n",
    "dataset_train.get_col_names()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following parameters in checkpoint files are not loaded:\n",
      "['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.gamma', 'cls.predictions.transform.LayerNorm.beta', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
      "The following parameters in models are missing parameter:\n",
      "['classifier.weight', 'classifier.bias']\n"
     ]
    }
   ],
   "source": [
    "from mindnlp.transformers import BertForSequenceClassification, BertModel\n",
    "from mindnlp._legacy.amp import auto_mixed_precision\n",
    "\n",
    "# Fine-tuning setup: a 3-label classification model loaded from 'bert-base-chinese'.\n",
    "NUM_LABELS = 3\n",
    "LEARNING_RATE = 2e-5\n",
    "\n",
    "model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=NUM_LABELS)\n",
    "# 'O1' is the level argument handed to auto_mixed_precision.\n",
    "model = auto_mixed_precision(model, 'O1')\n",
    "\n",
    "optimizer = nn.Adam(model.trainable_params(), learning_rate=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Accuracy is the metric reported after every evaluation pass.\n",
    "metric = Accuracy()\n",
    "\n",
    "# Callbacks: a periodic checkpoint (epochs=1, at most two files kept) plus a\n",
    "# best-model snapshot that is loaded back after training (auto_load=True).\n",
    "CHECKPOINT_DIR = 'checkpoint'\n",
    "ckpoint_cb = CheckpointCallback(\n",
    "    save_path=CHECKPOINT_DIR, ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2,\n",
    ")\n",
    "best_model_cb = BestModelCallback(\n",
    "    save_path=CHECKPOINT_DIR, ckpt_name='bert_emotect_best', auto_load=True,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    network=model,\n",
    "    train_dataset=dataset_train,\n",
    "    eval_dataset=dataset_val,\n",
    "    metrics=metric,\n",
    "    epochs=5,\n",
    "    optimizer=optimizer,\n",
    "    callbacks=[ckpoint_cb, best_model_cb],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The train will start from the checkpoint saved in 'checkpoint'.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|███████████████████████████████████████| 302/302 [00:38<00:00,  7.85it/s, loss=0.35258844]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: 'bert_emotect_epoch_0.ckpt' has been saved in epoch: 0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|█████████████████████████████████████████████████████████| 34/34 [00:02<00:00, 15.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.9046296296296297}\n",
      "---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 0.---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1: 100%|████████████████████████████████████████| 302/302 [00:35<00:00,  8.48it/s, loss=0.1903949]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checkpoint: 'bert_emotect_epoch_1.ckpt' has been saved in epoch: 1.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|█████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 17.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.9537037037037037}\n",
      "---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 1.---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2: 100%|███████████████████████████████████████| 302/302 [00:36<00:00,  8.39it/s, loss=0.13147199]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The maximum number of stored checkpoints has been reached.\n",
      "Checkpoint: 'bert_emotect_epoch_2.ckpt' has been saved in epoch: 2.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|█████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 17.65it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.9601851851851851}\n",
      "---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 2.---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3: 100%|██████████████████████████████████████| 302/302 [00:35<00:00,  8.39it/s, loss=0.096015364]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The maximum number of stored checkpoints has been reached.\n",
      "Checkpoint: 'bert_emotect_epoch_3.ckpt' has been saved in epoch: 3.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|█████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 17.74it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.9907407407407407}\n",
      "---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 3.---------------\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|███████████████████████████████████████| 302/302 [00:36<00:00,  8.23it/s, loss=0.06472951]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The maximum number of stored checkpoints has been reached.\n",
      "Checkpoint: 'bert_emotect_epoch_4.ckpt' has been saved in epoch: 4.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|█████████████████████████████████████████████████████████| 34/34 [00:01<00:00, 17.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.9953703703703703}\n",
      "---------------Best Model: 'bert_emotect_best.ckpt' has been saved in epoch: 4.---------------\n",
      "Loading best model from 'checkpoint' with '['Accuracy']': [0.9953703703703703]...\n",
      "---------------The model is already load the best model from 'bert_emotect_best.ckpt'.---------------\n"
     ]
    }
   ],
   "source": [
    "# Start training; the 'labels' column is designated as the target via tgt_columns.\n",
    "trainer.run(tgt_columns=\"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluate: 100%|█████████████████████████████████████████████████████████| 33/33 [00:01<00:00, 17.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate Score: {'Accuracy': 0.8841698841698842}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Score the trained model on the test split with the same Accuracy metric object.\n",
    "evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)\n",
    "evaluator.run(tgt_columns=\"labels\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load the inference split as raw (label, text) pairs; it is not passed through process_dataset.\n",
    "dataset_infer = SentimentDataset(\"data/infer.tsv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def predict(text, label=None):\n",
    "    \"\"\"Classify one sentence with the notebook-level model and print the result.\n",
    "\n",
    "    Args:\n",
    "        text: Raw sentence to classify.\n",
    "        label: Optional reference label id (a key of ``label_map``); when given,\n",
    "            it is appended to the printed line for comparison.\n",
    "\n",
    "    Relies on the notebook-level ``tokenizer`` and ``model`` objects and prints\n",
    "    a single line; nothing is returned.\n",
    "    \"\"\"\n",
    "    label_map = {0: \"消极\", 1: \"中性\", 2: \"积极\"}\n",
    "\n",
    "    # Put the network in inference mode so the prediction does not depend on\n",
    "    # whether an evaluation cell happened to run before this call.\n",
    "    model.set_train(False)\n",
    "    # Tensor is reached through the `mindspore` module imported in the first cell,\n",
    "    # so this function does not rely on a `Tensor` name bound by a later cell.\n",
    "    text_tokenized = mindspore.Tensor([tokenizer(text).input_ids])\n",
    "    logits = model(text_tokenized)\n",
    "    # argmax() yields a NumPy integer; convert it to a plain int for the lookup.\n",
    "    predict_label = int(logits[0].asnumpy().argmax())\n",
    "    info = f\"inputs: '{text}', predict: '{label_map[predict_label]}'\"\n",
    "    if label is not None:\n",
    "        info += f\" , label: '{label_map[label]}'\"\n",
    "    print(info)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs: '我 要 客观', predict: '中性' , label: '中性'\n",
      "inputs: '靠 你 真是 说 废话 吗', predict: '消极' , label: '消极'\n",
      "inputs: '口嗅 会', predict: '中性' , label: '中性'\n",
      "inputs: '每次 是 表妹 带 窝 飞 因为 窝路痴', predict: '中性' , label: '中性'\n",
      "inputs: '别说 废话 我 问 你 个 问题', predict: '消极' , label: '消极'\n",
      "inputs: '4967 是 新加坡 那 家 银行', predict: '中性' , label: '中性'\n",
      "inputs: '是 我 喜欢 兔子', predict: '积极' , label: '积极'\n",
      "inputs: '你 写 过 黄山 奇石 吗', predict: '中性' , label: '中性'\n",
      "inputs: '一个一个 慢慢来', predict: '中性' , label: '中性'\n",
      "inputs: '我 玩 过 这个 一点 都 不 好玩', predict: '消极' , label: '消极'\n",
      "inputs: '网上 开发 女孩 的 QQ', predict: '中性' , label: '中性'\n",
      "inputs: '背 你 猜 对 了', predict: '中性' , label: '中性'\n",
      "inputs: '我 讨厌 你 ， 哼哼 哼 。 。', predict: '消极' , label: '消极'\n"
     ]
    }
   ],
   "source": [
    "from mindspore import Tensor\n",
    "\n",
    "# Print prediction vs. reference label for every row of the inference split.\n",
    "# The loop variables have their own names: reusing `text` here would rebind the\n",
    "# notebook-level `text` module imported from mindspore.dataset in the first cell.\n",
    "for infer_label, infer_text in dataset_infer:\n",
    "    predict(infer_text, infer_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "inputs: '家人们咱就是说一整个无语住了 绝绝子叠buff', predict: '中性'\n"
     ]
    }
   ],
   "source": [
    "# Try a free-form sentence; no label is passed, so only the prediction is printed.\n",
    "predict(\"家人们咱就是说一整个无语住了 绝绝子叠buff\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
